{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "no_sessions = 3\n",
    "no_participants = 15\n",
    "no_channels = 62\n",
    "no_features = 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "%matplotlib inline\n",
    "%load_ext autoreload\n",
    "%autoreload 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from python import norm_functions\n",
    "from python import nn\n",
    "from python import nn_batch\n",
    "from python import nn_stratified\n",
    "from python import train_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x10ea9e070>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.manual_seed(16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## <a id='two_cat'>1. Two categories: Positive and negative</a> "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "no_videos = 10"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(675,)\n"
     ]
    }
   ],
   "source": [
    "labels_ = np.load('./data/emotion_labels.npy')\n",
    "print(labels_.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(675,)\n"
     ]
    }
   ],
   "source": [
    "participants_sessions_vector_ = np.load('./data/participants_sessions_vector.npy')\n",
    "print(participants_sessions_vector_.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(450,)\n"
     ]
    }
   ],
   "source": [
    "index_two_classes = []\n",
    "for i in range(len(labels_)):\n",
    "    if labels_[i]!=1:\n",
    "        index_two_classes.append(i)\n",
    "print(np.shape(index_two_classes))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(450,)\n",
      "(450,)\n"
     ]
    }
   ],
   "source": [
    "labels = labels_[index_two_classes]\n",
    "print(labels.shape)\n",
    "\n",
    "participants_sessions_vector = participants_sessions_vector_[index_two_classes]\n",
    "print(participants_sessions_vector.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### <a id='two_cat_batch'>1.1. Two categories: NN with batch normalization</a> "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(450, 248)\n"
     ]
    }
   ],
   "source": [
    "bandpower_SEED_ = np.load('./data/bandpower_SEED_de.npy')\n",
    "bandpower_SEED = bandpower_SEED_[index_two_classes]\n",
    "bandpower_SEED = norm_functions.normalization(bandpower_SEED, no_videos=10)\n",
    "print(bandpower_SEED.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Cross-subject NN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Participant  1\n",
      "NN no norm:  0.6000000238418579\n",
      "NN norm:  0.8666666746139526 \n",
      "\n",
      "Participant  2\n",
      "NN no norm:  0.699999988079071\n",
      "NN norm:  0.699999988079071 \n",
      "\n",
      "Participant  3\n",
      "NN no norm:  0.800000011920929\n",
      "NN norm:  0.7666666507720947 \n",
      "\n",
      "Participant  4\n",
      "NN no norm:  0.7666666507720947\n",
      "NN norm:  0.8333333134651184 \n",
      "\n",
      "Participant  5\n",
      "NN no norm:  0.6333333253860474\n",
      "NN norm:  0.8999999761581421 \n",
      "\n",
      "Participant  6\n",
      "NN no norm:  0.9666666388511658\n",
      "NN norm:  0.8999999761581421 \n",
      "\n",
      "Participant  7\n",
      "NN no norm:  0.5333333611488342\n",
      "NN norm:  0.9333333373069763 \n",
      "\n",
      "Participant  8\n",
      "NN no norm:  0.5\n",
      "NN norm:  0.9666666388511658 \n",
      "\n",
      "Participant  9\n",
      "NN no norm:  0.6333333253860474\n",
      "NN norm:  0.8999999761581421 \n",
      "\n",
      "Participant  10\n",
      "NN no norm:  0.8333333134651184\n",
      "NN norm:  0.9666666388511658 \n",
      "\n",
      "Participant  11\n",
      "NN no norm:  0.9666666388511658\n",
      "NN norm:  0.9333333373069763 \n",
      "\n",
      "Participant  12\n",
      "NN no norm:  0.6666666865348816\n",
      "NN norm:  0.6333333253860474 \n",
      "\n",
      "Participant  13\n",
      "NN no norm:  0.5666666626930237\n",
      "NN norm:  0.9666666388511658 \n",
      "\n",
      "Participant  14\n",
      "NN no norm:  0.800000011920929\n",
      "NN norm:  0.8666666746139526 \n",
      "\n",
      "Participant  15\n",
      "NN no norm:  0.8999999761581421\n",
      "NN norm:  1.0 \n",
      "\n"
     ]
    }
   ],
   "source": [
    "no_videos_session = no_sessions*no_videos\n",
    "nn_nonorm = []\n",
    "nn_norm = []\n",
    "\n",
    "for i in range(no_participants):\n",
    "    \n",
    "    if i==0:\n",
    "        train_x_cross_subject = bandpower_SEED[1*no_videos_session:]\n",
    "        val_x_cross_subject = bandpower_SEED[0:1*no_videos_session]\n",
    "\n",
    "        train_y_cross_subject = labels[1*no_videos_session:]\n",
    "        val_y_cross_subject = labels[0:1*no_videos_session]\n",
    "\n",
    "        train_i_cross_subject = participants_sessions_vector[1*no_videos_session:]\n",
    "        val_i_cross_subject = participants_sessions_vector[0:1*no_videos_session]\n",
    "        \n",
    "    elif i==(no_participants-1):\n",
    "        train_x_cross_subject = bandpower_SEED[:14*no_videos_session]\n",
    "        val_x_cross_subject = bandpower_SEED[14*no_videos_session:]\n",
    "\n",
    "        train_y_cross_subject = labels[:14*no_videos_session]\n",
    "        val_y_cross_subject = labels[14*no_videos_session:]\n",
    "\n",
    "        train_i_cross_subject = participants_sessions_vector[:14*no_videos_session]\n",
    "        val_i_cross_subject = participants_sessions_vector[14*no_videos_session:]\n",
    "                \n",
    "    else:\n",
    "        train_x_cross_subject = np.concatenate((bandpower_SEED[0:i*no_videos_session,:], bandpower_SEED[(i+1)*no_videos_session:,:]))\n",
    "        val_x_cross_subject = bandpower_SEED[i*no_videos_session:(i+1)*no_videos_session]\n",
    "\n",
    "        train_y_cross_subject = np.concatenate((labels[0:i*no_videos_session], labels[(i+1)*no_videos_session:]))\n",
    "        val_y_cross_subject = labels[i*no_videos_session:(i+1)*no_videos_session]\n",
    "\n",
    "        train_i_cross_subject = np.concatenate((participants_sessions_vector[0:i*no_videos_session], participants_sessions_vector[(i+1)*no_videos_session:]))\n",
    "        val_i_cross_subject = participants_sessions_vector[i*no_videos_session:(i+1)*no_videos_session]          \n",
    "\n",
    "      \n",
    "    ## NN nonorm\n",
    "    netNoNorm = nn.net()   \n",
    "    ts_acc = train_model.train_model_cross_subject(model=netNoNorm, \n",
    "                                                  train_x=train_x_cross_subject, \n",
    "                                                  test_x=val_x_cross_subject, \n",
    "                                                  train_y=train_y_cross_subject, \n",
    "                                                  test_y=val_y_cross_subject, \n",
    "                                                  train_i=train_i_cross_subject,\n",
    "                                                  test_i=val_i_cross_subject,\n",
    "                                                  no_epochs = 100,\n",
    "                                                  normalize=False)\n",
    "    nn_nonorm.append(ts_acc)\n",
    "\n",
    "    ## NN norm\n",
    "    netNorm = nn_batch.net_batch_norm()    \n",
    "    ts_acc = train_model.train_model_cross_subject(model=netNorm, \n",
    "                                                  train_x=train_x_cross_subject, \n",
    "                                                  test_x=val_x_cross_subject, \n",
    "                                                  train_y=train_y_cross_subject, \n",
    "                                                  test_y=val_y_cross_subject, \n",
    "                                                  train_i=train_i_cross_subject,\n",
    "                                                  test_i=val_i_cross_subject,\n",
    "                                                  no_epochs = 100,\n",
    "                                                  normalize=True)\n",
    "    nn_norm.append(ts_acc)\n",
    "    \n",
    "    print('Participant ', (i+1))\n",
    "    print('NN no norm: ', nn_nonorm[i])\n",
    "    print('NN norm: ', nn_norm[i], '\\n')\n",
    "    \n",
    "nn_nonorm = np.array(nn_nonorm)\n",
    "nn_norm = np.array(nn_norm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "NN no norm mean:  0.724\n",
      "NN no norm std:  0.146\n",
      "NN norm mean:  0.876\n",
      "NN norm std:  0.101\n"
     ]
    }
   ],
   "source": [
    "print('NN no norm mean: ', str(round(np.mean(nn_nonorm), 3)))\n",
    "print('NN no norm std: ', str(round(np.std(nn_nonorm), 3)))\n",
    "print('NN norm mean: ', str(round(np.mean(nn_norm), 3)))\n",
    "print('NN norm std: ', str(round(np.std(nn_norm), 3)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### <a id='two_cat_stratified'>1.2. Two categories: NN with stratified normalization</a> "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(450, 248)\n"
     ]
    }
   ],
   "source": [
    "bandpower_SEED_ = np.load('./data/bandpower_SEED_de.npy')\n",
    "bandpower_SEED = bandpower_SEED_[index_two_classes]\n",
    "bandpower_SEED = norm_functions.normalization_per_participant_session(bandpower_SEED, no_videos=10)\n",
    "print(bandpower_SEED.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Cross-subject NN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Participant  1\n",
      "NN no norm:  0.7333333492279053\n",
      "NN norm:  0.7666666507720947 \n",
      "\n",
      "Participant  2\n",
      "NN no norm:  0.8333333134651184\n",
      "NN norm:  0.800000011920929 \n",
      "\n",
      "Participant  3\n",
      "NN no norm:  0.8999999761581421\n",
      "NN norm:  0.8666666746139526 \n",
      "\n",
      "Participant  4\n",
      "NN no norm:  0.8999999761581421\n",
      "NN norm:  0.8999999761581421 \n",
      "\n",
      "Participant  5\n",
      "NN no norm:  0.8666666746139526\n",
      "NN norm:  1.0 \n",
      "\n",
      "Participant  6\n",
      "NN no norm:  0.8999999761581421\n",
      "NN norm:  0.9333333373069763 \n",
      "\n",
      "Participant  7\n",
      "NN no norm:  0.7666666507720947\n",
      "NN norm:  0.8333333134651184 \n",
      "\n",
      "Participant  8\n",
      "NN no norm:  0.9666666388511658\n",
      "NN norm:  1.0 \n",
      "\n",
      "Participant  9\n",
      "NN no norm:  0.800000011920929\n",
      "NN norm:  0.8999999761581421 \n",
      "\n",
      "Participant  10\n",
      "NN no norm:  0.8666666746139526\n",
      "NN norm:  0.9666666388511658 \n",
      "\n",
      "Participant  11\n",
      "NN no norm:  1.0\n",
      "NN norm:  0.9333333373069763 \n",
      "\n",
      "Participant  12\n",
      "NN no norm:  0.6000000238418579\n",
      "NN norm:  0.8999999761581421 \n",
      "\n",
      "Participant  13\n",
      "NN no norm:  0.8999999761581421\n",
      "NN norm:  0.9333333373069763 \n",
      "\n",
      "Participant  14\n",
      "NN no norm:  0.7666666507720947\n",
      "NN norm:  0.8666666746139526 \n",
      "\n",
      "Participant  15\n",
      "NN no norm:  1.0\n",
      "NN norm:  1.0 \n",
      "\n"
     ]
    }
   ],
   "source": [
    "no_videos_session = no_sessions*no_videos\n",
    "nn_nonorm = []\n",
    "nn_norm = []\n",
    "\n",
    "for i in range(no_participants):\n",
    "    \n",
    "    if i==0:\n",
    "        train_x_cross_subject = bandpower_SEED[1*no_videos_session:]\n",
    "        val_x_cross_subject = bandpower_SEED[0:1*no_videos_session]\n",
    "\n",
    "        train_y_cross_subject = labels[1*no_videos_session:]\n",
    "        val_y_cross_subject = labels[0:1*no_videos_session]\n",
    "\n",
    "        train_i_cross_subject = participants_sessions_vector[1*no_videos_session:]\n",
    "        val_i_cross_subject = participants_sessions_vector[0:1*no_videos_session]\n",
    "        \n",
    "    elif i==(no_participants-1):\n",
    "        train_x_cross_subject = bandpower_SEED[:14*no_videos_session]\n",
    "        val_x_cross_subject = bandpower_SEED[14*no_videos_session:]\n",
    "\n",
    "        train_y_cross_subject = labels[:14*no_videos_session]\n",
    "        val_y_cross_subject = labels[14*no_videos_session:]\n",
    "\n",
    "        train_i_cross_subject = participants_sessions_vector[:14*no_videos_session]\n",
    "        val_i_cross_subject = participants_sessions_vector[14*no_videos_session:]\n",
    "                \n",
    "    else:\n",
    "        train_x_cross_subject = np.concatenate((bandpower_SEED[0:i*no_videos_session,:], bandpower_SEED[(i+1)*no_videos_session:,:]))\n",
    "        val_x_cross_subject = bandpower_SEED[i*no_videos_session:(i+1)*no_videos_session]\n",
    "\n",
    "        train_y_cross_subject = np.concatenate((labels[0:i*no_videos_session], labels[(i+1)*no_videos_session:]))\n",
    "        val_y_cross_subject = labels[i*no_videos_session:(i+1)*no_videos_session]\n",
    "\n",
    "        train_i_cross_subject = np.concatenate((participants_sessions_vector[0:i*no_videos_session], participants_sessions_vector[(i+1)*no_videos_session:]))\n",
    "        val_i_cross_subject = participants_sessions_vector[i*no_videos_session:(i+1)*no_videos_session]       \n",
    "\n",
    "      \n",
    "    ## NN nonorm\n",
    "    netNoNorm = nn.net()   \n",
    "    ts_acc = train_model.train_model_cross_subject(model=netNoNorm, \n",
    "                                                  train_x=train_x_cross_subject, \n",
    "                                                  test_x=val_x_cross_subject, \n",
    "                                                  train_y=train_y_cross_subject, \n",
    "                                                  test_y=val_y_cross_subject, \n",
    "                                                  train_i=train_i_cross_subject,\n",
    "                                                  test_i=val_i_cross_subject,\n",
    "                                                  no_epochs = 100,\n",
    "                                                  normalize=False)\n",
    "    nn_nonorm.append(ts_acc)\n",
    "\n",
    "    ## NN norm\n",
    "    netNorm = nn_stratified.net_stratified_norm()    \n",
    "    ts_acc = train_model.train_model_cross_subject(model=netNorm, \n",
    "                                                  train_x=train_x_cross_subject, \n",
    "                                                  test_x=val_x_cross_subject, \n",
    "                                                  train_y=train_y_cross_subject, \n",
    "                                                  test_y=val_y_cross_subject, \n",
    "                                                  train_i=train_i_cross_subject,\n",
    "                                                  test_i=val_i_cross_subject,\n",
    "                                                  no_epochs = 100,\n",
    "                                                  normalize=True)\n",
    "    nn_norm.append(ts_acc)\n",
    "    \n",
    "    print('Participant ', (i+1))\n",
    "    print('NN no norm: ', nn_nonorm[i])\n",
    "    print('NN norm: ', nn_norm[i], '\\n')\n",
    "    \n",
    "nn_nonorm = np.array(nn_nonorm)\n",
    "nn_norm = np.array(nn_norm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "NN no norm mean:  0.853\n",
      "NN no norm std:  0.104\n",
      "NN norm mean:  0.907\n",
      "NN norm std:  0.069\n"
     ]
    }
   ],
   "source": [
    "print('NN no norm mean: ', str(round(np.mean(nn_nonorm), 3)))\n",
    "print('NN no norm std: ', str(round(np.std(nn_nonorm), 3)))\n",
    "print('NN norm mean: ', str(round(np.mean(nn_norm), 3)))\n",
    "print('NN norm std: ', str(round(np.std(nn_norm), 3)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## <a id='three_cat'>2. Three categories: Positive, neutral and negative</a> "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "no_videos = 15"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(675,)\n",
      "[2. 1. 0. 0. 1. 2. 0. 1. 2. 2. 1. 0. 1. 2. 0.]\n"
     ]
    }
   ],
   "source": [
    "labels = np.load('./data/emotion_labels.npy')\n",
    "print(labels.shape)\n",
    "print(labels[0:15])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(675,)\n"
     ]
    }
   ],
   "source": [
    "participants_sessions_vector = np.load('./data/participants_sessions_vector.npy')\n",
    "print(participants_sessions_vector.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### <a id='three_cat_batch'>2.1. Three categories: NN with batch normalization</a> "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(675, 248)\n"
     ]
    }
   ],
   "source": [
    "bandpower_SEED = np.load('./data/bandpower_SEED_de.npy')\n",
    "bandpower_SEED = norm_functions.normalization(bandpower_SEED, no_videos=15)\n",
    "print(bandpower_SEED.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Cross-subject NN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Participant  1\n",
      "NN no norm:  0.4888888895511627\n",
      "NN norm:  0.5333333611488342 \n",
      "\n",
      "Participant  2\n",
      "NN no norm:  0.3777777850627899\n",
      "NN norm:  0.4888888895511627 \n",
      "\n",
      "Participant  3\n",
      "NN no norm:  0.6888889074325562\n",
      "NN norm:  0.644444465637207 \n",
      "\n",
      "Participant  4\n",
      "NN no norm:  0.5777778029441833\n",
      "NN norm:  0.6888889074325562 \n",
      "\n",
      "Participant  5\n",
      "NN no norm:  0.5111111402511597\n",
      "NN norm:  0.7333333492279053 \n",
      "\n",
      "Participant  6\n",
      "NN no norm:  0.6222222447395325\n",
      "NN norm:  0.7333333492279053 \n",
      "\n",
      "Participant  7\n",
      "NN no norm:  0.4000000059604645\n",
      "NN norm:  0.5777778029441833 \n",
      "\n",
      "Participant  8\n",
      "NN no norm:  0.4444444477558136\n",
      "NN norm:  0.8222222328186035 \n",
      "\n",
      "Participant  9\n",
      "NN no norm:  0.5111111402511597\n",
      "NN norm:  0.644444465637207 \n",
      "\n",
      "Participant  10\n",
      "NN no norm:  0.7333333492279053\n",
      "NN norm:  0.644444465637207 \n",
      "\n",
      "Participant  11\n",
      "NN no norm:  0.8222222328186035\n",
      "NN norm:  0.800000011920929 \n",
      "\n",
      "Participant  12\n",
      "NN no norm:  0.5111111402511597\n",
      "NN norm:  0.4888888895511627 \n",
      "\n",
      "Participant  13\n",
      "NN no norm:  0.3777777850627899\n",
      "NN norm:  0.6666666865348816 \n",
      "\n",
      "Participant  14\n",
      "NN no norm:  0.6000000238418579\n",
      "NN norm:  0.644444465637207 \n",
      "\n",
      "Participant  15\n",
      "NN no norm:  0.800000011920929\n",
      "NN norm:  0.8888888955116272 \n",
      "\n"
     ]
    }
   ],
   "source": [
    "no_videos_session = no_sessions*no_videos\n",
    "nn_nonorm = []\n",
    "nn_norm = []\n",
    "\n",
    "for i in range(no_participants):\n",
    "    \n",
    "    if i==0:\n",
    "        train_x_cross_subject = bandpower_SEED[1*no_videos_session:]\n",
    "        val_x_cross_subject = bandpower_SEED[0:1*no_videos_session]\n",
    "\n",
    "        train_y_cross_subject = labels[1*no_videos_session:]\n",
    "        val_y_cross_subject = labels[0:1*no_videos_session]\n",
    "\n",
    "        train_i_cross_subject = participants_sessions_vector[1*no_videos_session:]\n",
    "        val_i_cross_subject = participants_sessions_vector[0:1*no_videos_session]\n",
    "        \n",
    "    elif i==(no_participants-1):\n",
    "        train_x_cross_subject = bandpower_SEED[:14*no_videos_session]\n",
    "        val_x_cross_subject = bandpower_SEED[14*no_videos_session:]\n",
    "\n",
    "        train_y_cross_subject = labels[:14*no_videos_session]\n",
    "        val_y_cross_subject = labels[14*no_videos_session:]\n",
    "\n",
    "        train_i_cross_subject = participants_sessions_vector[:14*no_videos_session]\n",
    "        val_i_cross_subject = participants_sessions_vector[14*no_videos_session:]\n",
    "                \n",
    "    else:\n",
    "        train_x_cross_subject = np.concatenate((bandpower_SEED[0:i*no_videos_session,:], bandpower_SEED[(i+1)*no_videos_session:,:]))\n",
    "        val_x_cross_subject = bandpower_SEED[i*no_videos_session:(i+1)*no_videos_session]\n",
    "\n",
    "        train_y_cross_subject = np.concatenate((labels[0:i*no_videos_session], labels[(i+1)*no_videos_session:]))\n",
    "        val_y_cross_subject = labels[i*no_videos_session:(i+1)*no_videos_session]\n",
    "\n",
    "        train_i_cross_subject = np.concatenate((participants_sessions_vector[0:i*no_videos_session], participants_sessions_vector[(i+1)*no_videos_session:]))\n",
    "        val_i_cross_subject = participants_sessions_vector[i*no_videos_session:(i+1)*no_videos_session]        \n",
    "\n",
    "      \n",
    "    ## NN nonorm\n",
    "    netNoNorm = nn.net()   \n",
    "    ts_acc = train_model.train_model_cross_subject(model=netNoNorm, \n",
    "                                                  train_x=train_x_cross_subject, \n",
    "                                                  test_x=val_x_cross_subject, \n",
    "                                                  train_y=train_y_cross_subject, \n",
    "                                                  test_y=val_y_cross_subject, \n",
    "                                                  train_i=train_i_cross_subject,\n",
    "                                                  test_i=val_i_cross_subject,\n",
    "                                                  no_epochs = 100,\n",
    "                                                  normalize=False)\n",
    "    nn_nonorm.append(ts_acc)\n",
    "\n",
    "    ## NN norm\n",
    "    netNorm = nn_batch.net_batch_norm()    \n",
    "    ts_acc = train_model.train_model_cross_subject(model=netNorm, \n",
    "                                                  train_x=train_x_cross_subject, \n",
    "                                                  test_x=val_x_cross_subject, \n",
    "                                                  train_y=train_y_cross_subject, \n",
    "                                                  test_y=val_y_cross_subject, \n",
    "                                                  train_i=train_i_cross_subject,\n",
    "                                                  test_i=val_i_cross_subject,\n",
    "                                                  no_epochs = 100,\n",
    "                                                  normalize=True)\n",
    "    nn_norm.append(ts_acc)\n",
    "    \n",
    "    print('Participant ', (i+1))\n",
    "    print('NN no norm: ', nn_nonorm[i])\n",
    "    print('NN norm: ', nn_norm[i], '\\n')\n",
    "    \n",
    "nn_nonorm = np.array(nn_nonorm)\n",
    "nn_norm = np.array(nn_norm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "NN no norm mean:  0.564\n",
      "NN no norm std:  0.14\n",
      "NN norm mean:  0.667\n",
      "NN norm std:  0.113\n"
     ]
    }
   ],
   "source": [
    "print('NN no norm mean: ', str(round(np.mean(nn_nonorm), 3)))\n",
    "print('NN no norm std: ', str(round(np.std(nn_nonorm), 3)))\n",
    "print('NN norm mean: ', str(round(np.mean(nn_norm), 3)))\n",
    "print('NN norm std: ', str(round(np.std(nn_norm), 3)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### <a id='three_cat_stratified'>2.2. Three categories: NN with stratified normalization</a> "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(675, 248)\n"
     ]
    }
   ],
   "source": [
    "bandpower_SEED = np.load('./data/bandpower_SEED_de.npy')\n",
    "bandpower_SEED = norm_functions.normalization_per_participant_session(bandpower_SEED, no_videos=15)\n",
    "print(bandpower_SEED.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Cross-subject NN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Participant  1\n",
      "NN no norm:  0.6222222447395325\n",
      "NN norm:  0.6666666865348816 \n",
      "\n",
      "Participant  2\n",
      "NN no norm:  0.6666666865348816\n",
      "NN norm:  0.6222222447395325 \n",
      "\n",
      "Participant  3\n",
      "NN no norm:  0.8888888955116272\n",
      "NN norm:  0.9111111164093018 \n",
      "\n",
      "Participant  4\n",
      "NN no norm:  0.7777777910232544\n",
      "NN norm:  0.8444444537162781 \n",
      "\n",
      "Participant  5\n",
      "NN no norm:  0.7777777910232544\n",
      "NN norm:  0.8888888955116272 \n",
      "\n",
      "Participant  6\n",
      "NN no norm:  0.7333333492279053\n",
      "NN norm:  0.7555555701255798 \n",
      "\n",
      "Participant  7\n",
      "NN no norm:  0.6666666865348816\n",
      "NN norm:  0.7111111283302307 \n",
      "\n",
      "Participant  8\n",
      "NN no norm:  0.9111111164093018\n",
      "NN norm:  0.8888888955116272 \n",
      "\n",
      "Participant  9\n",
      "NN no norm:  0.6666666865348816\n",
      "NN norm:  0.8666666746139526 \n",
      "\n",
      "Participant  10\n",
      "NN no norm:  0.7111111283302307\n",
      "NN norm:  0.7333333492279053 \n",
      "\n",
      "Participant  11\n",
      "NN no norm:  0.8222222328186035\n",
      "NN norm:  0.800000011920929 \n",
      "\n",
      "Participant  12\n",
      "NN no norm:  0.5111111402511597\n",
      "NN norm:  0.800000011920929 \n",
      "\n",
      "Participant  13\n",
      "NN no norm:  0.6000000238418579\n",
      "NN norm:  0.6000000238418579 \n",
      "\n",
      "Participant  14\n",
      "NN no norm:  0.6888889074325562\n",
      "NN norm:  0.6666666865348816 \n",
      "\n",
      "Participant  15\n",
      "NN no norm:  1.0\n",
      "NN norm:  0.9333333373069763 \n",
      "\n"
     ]
    }
   ],
   "source": [
    "no_videos_session = no_sessions*no_videos\n",
    "nn_nonorm = []\n",
    "nn_norm = []\n",
    "\n",
    "for i in range(no_participants):\n",
    "    \n",
    "    if i==0:\n",
    "        train_x_cross_subject = bandpower_SEED[1*no_videos_session:]\n",
    "        val_x_cross_subject = bandpower_SEED[0:1*no_videos_session]\n",
    "\n",
    "        train_y_cross_subject = labels[1*no_videos_session:]\n",
    "        val_y_cross_subject = labels[0:1*no_videos_session]\n",
    "\n",
    "        train_i_cross_subject = participants_sessions_vector[1*no_videos_session:]\n",
    "        val_i_cross_subject = participants_sessions_vector[0:1*no_videos_session]\n",
    "        \n",
    "    elif i==(no_participants-1):\n",
    "        train_x_cross_subject = bandpower_SEED[:14*no_videos_session]\n",
    "        val_x_cross_subject = bandpower_SEED[14*no_videos_session:]\n",
    "\n",
    "        train_y_cross_subject = labels[:14*no_videos_session]\n",
    "        val_y_cross_subject = labels[14*no_videos_session:]\n",
    "\n",
    "        train_i_cross_subject = participants_sessions_vector[:14*no_videos_session]\n",
    "        val_i_cross_subject = participants_sessions_vector[14*no_videos_session:]\n",
    "                \n",
    "    else:\n",
    "        train_x_cross_subject = np.concatenate((bandpower_SEED[0:i*no_videos_session,:], bandpower_SEED[(i+1)*no_videos_session:,:]))\n",
    "        val_x_cross_subject = bandpower_SEED[i*no_videos_session:(i+1)*no_videos_session]\n",
    "\n",
    "        train_y_cross_subject = np.concatenate((labels[0:i*no_videos_session], labels[(i+1)*no_videos_session:]))\n",
    "        val_y_cross_subject = labels[i*no_videos_session:(i+1)*no_videos_session]\n",
    "\n",
    "        train_i_cross_subject = np.concatenate((participants_sessions_vector[0:i*no_videos_session], participants_sessions_vector[(i+1)*no_videos_session:]))\n",
    "        val_i_cross_subject = participants_sessions_vector[i*no_videos_session:(i+1)*no_videos_session]    \n",
    "\n",
    "      \n",
    "    ## NN nonorm\n",
    "    netNoNorm = nn.net()   \n",
    "    ts_acc = train_model.train_model_cross_subject(model=netNoNorm, \n",
    "                                                  train_x=train_x_cross_subject, \n",
    "                                                  test_x=val_x_cross_subject, \n",
    "                                                  train_y=train_y_cross_subject, \n",
    "                                                  test_y=val_y_cross_subject, \n",
    "                                                  train_i=train_i_cross_subject,\n",
    "                                                  test_i=val_i_cross_subject,\n",
    "                                                  no_epochs = 100,\n",
    "                                                  normalize=False)\n",
    "    nn_nonorm.append(ts_acc)\n",
    "\n",
    "    ## NN norm\n",
    "    netNorm = nn_stratified.net_stratified_norm()    \n",
    "    ts_acc = train_model.train_model_cross_subject(model=netNorm, \n",
    "                                                  train_x=train_x_cross_subject, \n",
    "                                                  test_x=val_x_cross_subject, \n",
    "                                                  train_y=train_y_cross_subject, \n",
    "                                                  test_y=val_y_cross_subject, \n",
    "                                                  train_i=train_i_cross_subject,\n",
    "                                                  test_i=val_i_cross_subject,\n",
    "                                                  no_epochs = 100,\n",
    "                                                  normalize=True)\n",
    "    nn_norm.append(ts_acc)\n",
    "    \n",
    "    print('Participant ', (i+1))\n",
    "    print('NN no norm: ', nn_nonorm[i])\n",
    "    print('NN norm: ', nn_norm[i], '\\n')\n",
    "    \n",
    "nn_nonorm = np.array(nn_nonorm)\n",
    "nn_norm = np.array(nn_norm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "NN no norm mean:  0.736\n",
      "NN no norm std:  0.125\n",
      "NN norm mean:  0.779\n",
      "NN norm std:  0.106\n"
     ]
    }
   ],
   "source": [
    "print('NN no norm mean: ', str(round(np.mean(nn_nonorm), 3)))\n",
    "print('NN no norm std: ', str(round(np.std(nn_nonorm), 3)))\n",
    "print('NN norm mean: ', str(round(np.mean(nn_norm), 3)))\n",
    "print('NN norm std: ', str(round(np.std(nn_norm), 3)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
